{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "95274c2d",
   "metadata": {},
   "source": [
    "# End to End Quantization Test\n",
    "This notebook tests our quantization implementation in one place by making sure our operations yield good accuracy on the end task. We don't check for exact correctness of the output since quantization will introduce errors that we do not worry about unless they cause an accuracy drop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "f18278d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import torch\n",
    "from transformers import glue_compute_metrics\n",
    "import sklearn\n",
    "import math\n",
    "from sklearn.metrics import f1_score\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from transformers import RobertaForSequenceClassification, AutoTokenizer\n",
    "from transformers.data.metrics import simple_accuracy\n",
    "from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions\n",
    "import torch.nn as nn\n",
    "from torch.nn import CrossEntropyLoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b867574a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (/Users/oliver/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "Some weights of the model checkpoint at textattack/roberta-base-MRPC were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Loading cached processed dataset at /Users/oliver/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-384ba54281f40c06.arrow\n",
      "Loading cached processed dataset at /Users/oliver/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9ee5801daad855f5.arrow\n"
     ]
    }
   ],
   "source": [
    "from src.utils import roberta_mrpc_dataset\n",
    "dataset = roberta_mrpc_dataset()\n",
    "\n",
    "dataset = dataset.map(encode, batched=True)\n",
    "dataset = dataset.map(lambda examples: {'labels': examples['label']}, batched=True)\n",
    "dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "bfd0b06a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def attention(layer, hidden_states, attention_mask=None):\n",
    "    '''\n",
    "    Pass in a encoder layer (which holds pretrained weights) and hidden_states input,\n",
    "    and this function performs the same operations as the layer but in a readable fashion.\n",
    "    \n",
    "    hidden_states: <bs, seqlen, dmodel>\n",
    "    '''\n",
    "    bs, seqlen, dmodel = hidden_states.size()\n",
    "    num_heads = layer.attention.self.num_attention_heads\n",
    "    dhead = layer.attention.self.attention_head_size\n",
    "    \n",
    "    # Linear transform to get multiple heads. This is a major MAC slurper.\n",
    "    # Each of these is calling an nn.Linear layer on hidden_states.\n",
    "#     query_layer = layer.attention.self.query(hidden_states) # <bs, seqlen, dmodel>\n",
    "    query_layer = torch.matmul(hidden_states, layer.attention.self.query.weight.T) + layer.attention.self.query.bias\n",
    "    key_layer = layer.attention.self.key(hidden_states)     # <bs, seqlen, dmodel>\n",
    "    value_layer = layer.attention.self.value(hidden_states) # <bs, seqlen, dmodel>\n",
    "    \n",
    "    # Reshape and transpose for multi-head\n",
    "    new_shape = (bs, seqlen, num_heads, dhead)\n",
    "    \n",
    "    query_layer = query_layer.view(new_shape)\n",
    "    value_layer = value_layer.view(new_shape)\n",
    "    key_layer = key_layer.view(new_shape)\n",
    "    \n",
    "    query_layer = query_layer.permute(0,2,1,3) # <bs, num_head, seqlen, dhead>\n",
    "    value_layer = value_layer.permute(0,2,1,3) # <bs, num_head, seqlen, dhead>\n",
    "    # Key is transposed to match dimensions of Query for matmul\n",
    "    key_layer = key_layer.permute(0,2,3,1)     # <bs, num_head, dhead, seqlen>\n",
    "    \n",
    "    # The attention main course\n",
    "    attention_scores = torch.matmul(query_layer, key_layer)\n",
    "    attention_scores /= math.sqrt(dhead)\n",
    "    \n",
    "    if attention_mask is not None:\n",
    "        # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)\n",
    "        attention_scores = attention_scores + attention_mask\n",
    "    \n",
    "    attention_probs = nn.functional.softmax(attention_scores, dim=-1)\n",
    "    attention_probs = layer.attention.self.dropout(attention_probs)\n",
    "    \n",
    "    # Weighted sum of Values from softmax attention\n",
    "    attention_out = torch.matmul(attention_probs, value_layer)\n",
    "    \n",
    "    attention_out = attention_out.permute(0,2,1,3).contiguous()\n",
    "    attention_out = attention_out.view(bs, seqlen, dmodel)\n",
    "    \n",
    "    # It's time for one more linear transform and layer norm\n",
    "    dense_out = layer.attention.output.dense(attention_out)\n",
    "    dense_out = layer.attention.output.dropout(dense_out)\n",
    "    \n",
    "    # LayerNorm also mplements the residual connection\n",
    "    layer_out = layer.attention.output.LayerNorm(dense_out + hidden_states)\n",
    "    \n",
    "    return layer_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "75352287",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ffn(layer, attention_out):\n",
    "    '''\n",
    "    Pass in the feedforward layer and attention output. Returns the same result of a FF forward pass.\n",
    "    '''\n",
    "    # Layer 1\n",
    "    output = layer.intermediate.dense(attention_out)\n",
    "    output = nn.functional.gelu(output)\n",
    "    \n",
    "    # Layer 2\n",
    "    output = layer.output.dense(output)\n",
    "    output = layer.output.dropout(output)\n",
    "    output = layer.output.LayerNorm(output + attention_out)\n",
    "    \n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "c280f259",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encoder_stack(model, hidden_states, attention_mask):\n",
    "    for layer_module in model.roberta.encoder.layer:\n",
    "        # MHA + LayerNorm\n",
    "        attention_out = attention(layer_module, hidden_states, attention_mask)\n",
    "        ff_out = ffn(layer_module, attention_out)\n",
    "        hidden_states = ff_out\n",
    "    sequence_output = hidden_states\n",
    "    pooled_output = model.roberta.pooler(hidden_states) if model.roberta.pooler is not None else None\n",
    "    \n",
    "    return BaseModelOutputWithPoolingAndCrossAttentions(\n",
    "            last_hidden_state=sequence_output,\n",
    "            pooler_output=pooled_output,\n",
    "            past_key_values=None,\n",
    "            hidden_states=None,\n",
    "            attentions=None,\n",
    "            cross_attentions=None,\n",
    "        )\n",
    "        \n",
    "def sequence_classification(model,\n",
    "                            outputs, \n",
    "                            input_ids=None,\n",
    "                            attention_mask=None,\n",
    "                            token_type_ids=None,\n",
    "                            position_ids=None,\n",
    "                            head_mask=None,\n",
    "                            inputs_embeds=None,\n",
    "                            labels=None,\n",
    "                            output_attentions=None,\n",
    "                            output_hidden_states=None,\n",
    "                            return_dict=None,):\n",
    "    \n",
    "    sequence_output = outputs[0]\n",
    "    logits = model.classifier(sequence_output)\n",
    "\n",
    "    loss = None\n",
    "    if labels is not None:\n",
    "        if model.config.problem_type is None:\n",
    "            if model.num_labels == 1:\n",
    "                model.config.problem_type = \"regression\"\n",
    "            elif model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n",
    "                model.config.problem_type = \"single_label_classification\"\n",
    "            else:\n",
    "                model.config.problem_type = \"multi_label_classification\"\n",
    "\n",
    "        if model.config.problem_type == \"regression\":\n",
    "            loss_fct = MSELoss()\n",
    "            if self.num_labels == 1:\n",
    "                loss = loss_fct(logits.squeeze(), labels.squeeze())\n",
    "            else:\n",
    "                loss = loss_fct(logits, labels)\n",
    "        elif model.config.problem_type == \"single_label_classification\":\n",
    "            loss_fct = CrossEntropyLoss()\n",
    "            loss = loss_fct(logits.view(-1, model.num_labels), labels.view(-1))\n",
    "        elif model.config.problem_type == \"multi_label_classification\":\n",
    "            loss_fct = BCEWithLogitsLoss()\n",
    "            loss = loss_fct(logits, labels)\n",
    "\n",
    "    if not return_dict:\n",
    "        output = (logits,) + outputs[2:]\n",
    "        return ((loss,) + output) if loss is not None else output\n",
    "\n",
    "    return SequenceClassifierOutput(\n",
    "        loss=loss,\n",
    "        logits=logits,\n",
    "        hidden_states=outputs.hidden_states,\n",
    "        attentions=outputs.attentions,\n",
    "    )\n",
    "\n",
    "def eval_model(model, dataloader):\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    model.eval()\n",
    "    preds = None\n",
    "\n",
    "    for i, batch in enumerate(tqdm(dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "\n",
    "        with torch.no_grad():\n",
    "            embedding_output = model.roberta.embeddings(\n",
    "                input_ids=batch['input_ids'],\n",
    "                position_ids=None,\n",
    "                token_type_ids=None,\n",
    "                inputs_embeds=None,\n",
    "                past_key_values_length=0,\n",
    "            )\n",
    "            \n",
    "            extended_attention_mask = model.roberta.get_extended_attention_mask(batch['attention_mask'], batch['input_ids'].size(), device)\n",
    "            outputs = encoder_stack(model, embedding_output, extended_attention_mask)\n",
    "            outputs = sequence_classification(model, outputs, **batch)\n",
    "\n",
    "            outputs_gt = model(**batch)\n",
    "            \n",
    "            tmp_eval_loss, logits = outputs[:2]\n",
    "            _, logits_gt = outputs_gt[:2]\n",
    "            \n",
    "            \n",
    "            assert torch.allclose(logits_gt, logits)\n",
    "            \n",
    "            loss = outputs[0]\n",
    "        if preds is None:\n",
    "            preds = logits.detach().cpu().numpy()\n",
    "            out_label_ids = batch['labels'].detach().cpu().numpy()\n",
    "        else:\n",
    "            preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)\n",
    "            out_label_ids = np.append(out_label_ids, batch['labels'].detach().cpu().numpy(), axis=0)\n",
    "        if i % 10 == 0:\n",
    "    #         print(f\"loss: {loss}\")\n",
    "            pass\n",
    "\n",
    "    preds = np.argmax(preds, axis=1)\n",
    "\n",
    "    print(f'accuracy: {simple_accuracy(preds, out_label_ids)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "b1678f4d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|                                                                                                                                                                              | 0/13 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[    0,   894,    26,  ...,     1,     1,     1],\n",
      "        [    0, 43600,  1322,  ...,     1,     1,     1],\n",
      "        [    0,   133,  1404,  ...,     1,     1,     1],\n",
      "        ...,\n",
      "        [    0,   133,  5259,  ...,     1,     1,     1],\n",
      "        [    0,  3908,   209,  ...,     1,     1,     1],\n",
      "        [    0, 33553,    21,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1,\n",
      "        1, 1, 0, 1, 1, 1, 0, 1]))])\n",
      "torch.Size([32, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  8%|████████████▊                                                                                                                                                         | 1/13 [00:34<06:48, 34.02s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[    0, 10980,  3921,  ...,     1,     1,     1],\n",
      "        [    0,  3750,   435,  ...,     1,     1,     1],\n",
      "        [    0,   894,   156,  ...,     1,     1,     1],\n",
      "        ...,\n",
      "        [    0, 36035,     9,  ...,     1,     1,     1],\n",
      "        [    0,   113,  2477,  ...,     1,     1,     1],\n",
      "        [    0,  1711,    74,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "        1, 1, 1, 1, 0, 1, 1, 1]))])\n",
      "torch.Size([32, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 15%|█████████████████████████▌                                                                                                                                            | 2/13 [01:09<06:21, 34.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[    0,  4763, 38805,  ...,     1,     1,     1],\n",
      "        [    0,   133,  4614,  ...,     1,     1,     1],\n",
      "        [    0,   347,  2723,  ...,     1,     1,     1],\n",
      "        ...,\n",
      "        [    0, 15248,  6909,  ...,     1,     1,     1],\n",
      "        [    0, 43294,    67,  ...,     1,     1,     1],\n",
      "        [    0, 42609,    42,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0,\n",
      "        1, 1, 0, 1, 1, 0, 1, 1]))])\n",
      "torch.Size([32, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 23%|██████████████████████████████████████▎                                                                                                                               | 3/13 [01:46<05:59, 35.95s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[    0,   133,   690,  ...,     1,     1,     1],\n",
      "        [    0, 28188,    13,  ...,     1,     1,     1],\n",
      "        [    0,  1106,     5,  ...,     1,     1,     1],\n",
      "        ...,\n",
      "        [    0,  9497,   224,  ...,     1,     1,     1],\n",
      "        [    0,  4771, 11680,  ...,     1,     1,     1],\n",
      "        [    0, 28012,   282,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,\n",
      "        1, 1, 0, 1, 1, 1, 1, 0]))])\n",
      "torch.Size([32, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 31%|███████████████████████████████████████████████████                                                                                                                   | 4/13 [02:23<05:28, 36.50s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[    0,   133,   908,  ...,     1,     1,     1],\n",
      "        [    0,  3084,   138,  ...,     1,     1,     1],\n",
      "        [    0, 32703,    12,  ...,     1,     1,     1],\n",
      "        ...,\n",
      "        [    0,   133, 23781,  ...,     1,     1,     1],\n",
      "        [    0, 30019, 20906,  ...,     1,     1,     1],\n",
      "        [    0, 12582, 40787,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1,\n",
      "        1, 1, 1, 1, 1, 0, 0, 1]))])\n",
      "torch.Size([32, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 38%|███████████████████████████████████████████████████████████████▊                                                                                                      | 5/13 [03:01<04:54, 36.84s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[    0, 21674,   112,  ...,     1,     1,     1],\n",
      "        [    0, 44888,  1322,  ...,     1,     1,     1],\n",
      "        [    0,  3762,     9,  ...,     1,     1,     1],\n",
      "        ...,\n",
      "        [    0,  5771,     5,  ...,     1,     1,     1],\n",
      "        [    0, 42510,    56,  ...,     1,     1,     1],\n",
      "        [    0,   133, 13387,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1,\n",
      "        0, 1, 0, 1, 0, 1, 1, 0]))])\n",
      "torch.Size([32, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 46%|████████████████████████████████████████████████████████████████████████████▌                                                                                         | 6/13 [03:37<04:16, 36.60s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[    0, 35346,  2993,  ...,     1,     1,     1],\n",
      "        [    0,  1121,   130,  ...,     1,     1,     1],\n",
      "        [    0, 21674,    80,  ...,     1,     1,     1],\n",
      "        ...,\n",
      "        [    0,   133,   796,  ...,     1,     1,     1],\n",
      "        [    0,   113,   407,  ...,     1,     1,     1],\n",
      "        [    0, 37703,   135,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1,\n",
      "        1, 1, 0, 1, 1, 1, 1, 0]))])\n",
      "torch.Size([32, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 54%|█████████████████████████████████████████████████████████████████████████████████████████▍                                                                            | 7/13 [04:12<03:35, 35.95s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[   0,  894,   16,  ...,    1,    1,    1],\n",
      "        [   0,  133,  604,  ...,    1,    1,    1],\n",
      "        [   0,  113,  440,  ...,    1,    1,    1],\n",
      "        ...,\n",
      "        [   0, 4993,  504,  ...,    1,    1,    1],\n",
      "        [   0,  565, 3937,  ...,    1,    1,    1],\n",
      "        [   0,  113,  152,  ...,    1,    1,    1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0,\n",
      "        1, 1, 1, 0, 1, 0, 1, 0]))])\n",
      "torch.Size([32, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 62%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                               | 8/13 [04:46<02:57, 35.53s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[    0, 30888,    12,  ...,     1,     1,     1],\n",
      "        [    0, 36310,  2156,  ...,     1,     1,     1],\n",
      "        [    0,  1121,     5,  ...,     1,     1,     1],\n",
      "        ...,\n",
      "        [    0,   113,   166,  ...,     1,     1,     1],\n",
      "        [    0, 10980,     4,  ...,     1,     1,     1],\n",
      "        [    0,   133,   316,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
      "        1, 1, 1, 0, 0, 1, 1, 1]))])\n",
      "torch.Size([32, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 69%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                   | 9/13 [05:21<02:20, 35.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[    0,  1121,   902,  ...,     1,     1,     1],\n",
      "        [    0,   100,    21,  ...,     1,     1,     1],\n",
      "        [    0,   133,   194,  ...,     1,     1,     1],\n",
      "        ...,\n",
      "        [    0,   113,    85,  ...,     1,     1,     1],\n",
      "        [    0,   717,  6712,  ...,     1,     1,     1],\n",
      "        [    0, 16025,  2614,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1,\n",
      "        1, 0, 1, 0, 1, 1, 0, 0]))])\n",
      "torch.Size([32, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 77%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                      | 10/13 [05:55<01:44, 34.97s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[    0,  9497,  3986,  ...,     1,     1,     1],\n",
      "        [    0, 14229,    42,  ...,     1,     1,     1],\n",
      "        [    0,   133,   806,  ...,     1,     1,     1],\n",
      "        ...,\n",
      "        [    0,   250,   320,  ...,     1,     1,     1],\n",
      "        [    0, 11329,   219,  ...,     1,     1,     1],\n",
      "        [    0, 40145,  5369,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0,\n",
      "        1, 1, 1, 1, 0, 1, 0, 1]))])\n",
      "torch.Size([32, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                         | 11/13 [06:29<01:09, 34.68s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[    0,   133, 13640,  ...,     1,     1,     1],\n",
      "        [    0,   243,    67,  ...,     1,     1,     1],\n",
      "        [    0,   243,  1276,  ...,     1,     1,     1],\n",
      "        ...,\n",
      "        [    0, 35779,    32,  ...,     1,     1,     1],\n",
      "        [    0,  8800,   388,  ...,     1,     1,     1],\n",
      "        [    0,   133,  4079,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0,\n",
      "        0, 0, 1, 1, 1, 1, 0, 1]))])\n",
      "torch.Size([32, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎            | 12/13 [07:06<00:35, 35.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_items([('input_ids', tensor([[    0,   133,  1576,  ...,     1,     1,     1],\n",
      "        [    0,  1121,     5,  ...,     1,     1,     1],\n",
      "        [    0, 42047, 12637,  ...,     1,     1,     1],\n",
      "        ...,\n",
      "        [    0,  5625,    11,  ...,     1,     1,     1],\n",
      "        [    0,   133,   208,  ...,     1,     1,     1],\n",
      "        [    0,  9089,    12,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1]))])\n",
      "torch.Size([24, 512])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [09:35<00:00, 44.29s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy: 0.9117647058823529\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "/Users/oliver/miniforge3/envs/trans-fat/lib/python3.9/site-packages/transformers/data/metrics/__init__.py:36: FutureWarning: This metric will be removed from the library soon, metrics should be handled with the 🤗 Datasets library. You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-classification/run_glue.py\n",
      "  warnings.warn(DEPRECATION_WARNING, FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "eval_model(model, dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7efc44a8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "trans-fat",
   "language": "python",
   "name": "trans-fat"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
